{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/function_space_linearization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9uPYkWOcghJm",
    "pycharm": {}
   },
   "source": [
    "##### Copyright 2019 Google LLC.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YDnknGorgv2O",
    "pycharm": {}
   },
   "source": [
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8KPv0bOW6UCi",
    "pycharm": {}
   },
   "source": [
    "#### Import & Utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cxFbqXZKhGW0",
    "pycharm": {}
   },
   "source": [
    "Install JAX, Tensorflow Datasets, and Neural Tangents.\n",
    "\n",
    "The first line specifies the version of jaxlib that we would like to import. Note, that \"cp36\" species the version of python (version 3.6) used by JAX. Make sure your colab kernel matches this version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "g_gSbMyUhF92",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
        "!pip install -q --upgrade pip\n",
        "!pip install -q --upgrade jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
        "!pip install -q tensorflow-datasets\n",
        "!pip install -q git+https://www.github.com/google/neural-tangents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "knIftr57X055",
    "pycharm": {}
   },
   "source": [
    "Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8D0i89hRmNoC",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "from jax import jit\n",
    "from jax import grad\n",
    "from jax import random\n",
    "\n",
    "import jax.numpy as jnp\n",
    "from jax.nn import log_softmax\n",
    "from jax.example_libraries import optimizers\n",
    "\n",
    "import tensorflow_datasets as tfds\n",
    "\n",
    "import neural_tangents as nt\n",
    "from neural_tangents import stax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_bbZz-nWX4Hj",
    "pycharm": {}
   },
   "source": [
    "Define helper functions for processing data and defining a vanilla momentum optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-W1ws1B-6_vq",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "def process_data(data_chunk):\n",
    "  \"\"\"Flatten the images and one-hot encode the labels.\"\"\"\n",
    "  image, label = data_chunk['image'], data_chunk['label']\n",
    "\n",
    "  samples = image.shape[0]\n",
    "  image = jnp.array(jnp.reshape(image, (samples, -1)), dtype=jnp.float32)\n",
    "  image = (image - jnp.mean(image)) / jnp.std(image)\n",
    "  label = jnp.eye(10)[label]\n",
    "\n",
    "  return {'image': image, 'label': label}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ik27L4izDK9s",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "@optimizers.optimizer\n",
    "def momentum(learning_rate, momentum=0.9):\n",
    "  \"\"\"A standard momentum optimizer for testing.\n",
    "\n",
    "  Different from `jax.example_libraries.optimizers.momentum` (Nesterov).\n",
    "  \"\"\"\n",
    "  learning_rate = optimizers.make_schedule(learning_rate)\n",
    "  def init_fn(x0):\n",
    "    v0 = jnp.zeros_like(x0)\n",
    "    return x0, v0\n",
    "  def update_fn(i, g, state):\n",
    "    x, velocity = state\n",
    "    velocity = momentum * velocity + g\n",
    "    x = x - learning_rate(i) * velocity\n",
    "    return x, velocity\n",
    "  def get_params(state):\n",
    "    x, _ = state\n",
    "    return x\n",
    "  return init_fn, update_fn, get_params\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "32Wvhil9X8IK",
    "pycharm": {}
   },
   "source": [
    "# Function Space Linearization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JJ_zDKsKcDB-",
    "pycharm": {}
   },
   "source": [
    "Create MNIST data pipeline using TensorFlow Datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5llaSqZW4Et3",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "dataset_size = 64\n",
    "\n",
    "ds_train, ds_test = tfds.as_numpy(\n",
    "    tfds.load('mnist:3.*.*', split=['train[:%d]' % dataset_size,\n",
    "                                    'test[:%d]' % dataset_size],\n",
    "              batch_size=-1)\n",
    ")\n",
    "\n",
    "train = process_data(ds_train)\n",
    "test = process_data(ds_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ajz_oTOw72v8",
    "pycharm": {}
   },
   "source": [
    "Setup some experiment parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UtjfeaYC72Gs",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "learning_rate = 1e0\n",
    "training_steps = jnp.arange(1000)\n",
    "print_every = 100.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1-nKR--j5p2C",
    "pycharm": {}
   },
   "source": [
    "Create a Fully-Connected Network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wIbfrdzq5pLZ",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "init_fn, f, _ = stax.serial(\n",
    "    stax.Dense(512, 1., 0.05),\n",
    "    stax.Erf(),\n",
    "    stax.Dense(10, 1., 0.05))\n",
    "\n",
    "key = random.PRNGKey(0)\n",
    "_, params = init_fn(key, (-1, 784))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "c9zgKt9B8NBt",
    "pycharm": {}
   },
   "source": [
    "Construct the NTK."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bU6ccJM_8LWt",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "ntk = nt.batch(nt.empirical_ntk_fn(f, vmap_axes=0),\n",
    "               batch_size=64, device_count=0)\n",
    "\n",
    "g_dd = ntk(train['image'], None, params)\n",
    "g_td = ntk(test['image'], train['image'], params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jdR-lIW11Vbj",
    "pycharm": {}
   },
   "source": [
    "Now that we have the NTK and a network we can compare against a number of different dynamics. Remember to reinitialize the network and NTK if you want to try a different dynamics."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hVesciX61bGb",
    "pycharm": {}
   },
   "source": [
    "## Gradient Descent, MSE Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lrp9YNCt7nCj",
    "pycharm": {}
   },
   "source": [
    "Create a optimizer and initialize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "J-8i_4KD7o5s",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)\n",
    "state = opt_init(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NspVdDOU8mhk",
    "pycharm": {}
   },
   "source": [
    "Create an MSE loss and a gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "z6L-LzyF8qLW",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)\n",
    "grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f57Teh1317hn",
    "pycharm": {}
   },
   "source": [
    "Create an MSE predictor and compute the function space values of the network at initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7UH_uOxz16w2",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "predictor = nt.predict.gradient_descent_mse(g_dd, train['label'],\n",
    "                                            learning_rate=learning_rate)\n",
    "fx_train = f(params, train['image'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rWROOyCZ9u6N",
    "pycharm": {}
   },
   "source": [
    "Train the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 204
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2319,
     "status": "ok",
     "timestamp": 1588652647308,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 420
    },
    "id": "WXeof-AB8BiV",
    "outputId": "668fec81-aa7c-4b57-96f3-28b766e2cf35",
    "pycharm": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time\tLoss\tLinear Loss\n",
      "0\t0.2444\t0.2444\n",
      "100\t0.1231\t0.1234\n",
      "200\t0.0854\t0.0855\n",
      "300\t0.0652\t0.0649\n",
      "400\t0.0523\t0.0519\n",
      "500\t0.0434\t0.0427\n",
      "600\t0.0367\t0.0359\n",
      "700\t0.0315\t0.0306\n",
      "800\t0.0273\t0.0263\n",
      "900\t0.0239\t0.0229\n"
     ]
    }
   ],
   "source": [
    "print ('Time\\tLoss\\tLinear Loss')\n",
    "\n",
    "X, Y = train['image'], train['label']\n",
    "\n",
    "predictions = predictor(training_steps, fx_train)\n",
    "\n",
    "for i in training_steps:\n",
    "  params = get_params(state)\n",
    "  state = opt_apply(i, grad_loss(params, X, Y), state)\n",
    "\n",
    "  if i % print_every == 0:\n",
    "    exact_loss = loss(f(params, X), Y)\n",
    "    linear_loss = loss(predictions[i], Y)\n",
    "    print('{}\\t{:.4f}\\t{:.4f}'.format(i, exact_loss, linear_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gx65YR3A8_yd",
    "pycharm": {}
   },
   "source": [
    "## Gradient Descent, Cross Entropy Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8jEb5V9C8_yd",
    "pycharm": {}
   },
   "source": [
    "Create a optimizer and initialize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VKfuj6O88_ye",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "opt_init, opt_apply, get_params = optimizers.sgd(learning_rate)\n",
    "state = opt_init(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hpWaHdvH8_yg",
    "pycharm": {}
   },
   "source": [
    "Create an Cross Entropy loss and a gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zQ03wQ7O8_yh",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "loss = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)\n",
    "grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WgS4k3878_yi",
    "pycharm": {}
   },
   "source": [
    "Create a Gradient Descent predictor and compute the function space values of the network at initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "h2uIi4mQ8_yi",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "predictor = nt.predict.gradient_descent(loss, g_dd, train['label'], learning_rate=learning_rate)\n",
    "fx_train = f(params, train['image'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tRh7Ur9Y8_yj",
    "pycharm": {}
   },
   "source": [
    "Train the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 204
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 4141,
     "status": "ok",
     "timestamp": 1588652652808,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 420
    },
    "id": "FnW6DNWf8_yj",
    "outputId": "93101ca1-bf59-4262-ae7c-a7fecfc1a7b2",
    "pycharm": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time\tLoss\tLinear Loss\n",
      "0\t0.1696\t0.1696\n",
      "100\t0.1497\t0.1493\n",
      "200\t0.1336\t0.1329\n",
      "300\t0.1204\t0.1195\n",
      "400\t0.1093\t0.1083\n",
      "500\t0.0998\t0.0987\n",
      "600\t0.0916\t0.0906\n",
      "700\t0.0845\t0.0835\n",
      "800\t0.0783\t0.0773\n",
      "900\t0.0728\t0.0719\n"
     ]
    }
   ],
   "source": [
    "print ('Time\\tLoss\\tLinear Loss')\n",
    "\n",
    "X, Y = train['image'], train['label']\n",
    "\n",
    "predictions = predictor(training_steps, fx_train)\n",
    "\n",
    "for i in training_steps:\n",
    "  params = get_params(state)\n",
    "  state = opt_apply(i, grad_loss(params, X, Y), state)\n",
    "\n",
    "  if i % print_every == 0:\n",
    "    t = i * learning_rate\n",
    "    exact_loss = loss(f(params, X), Y)\n",
    "    linear_loss = loss(predictions[i], Y)\n",
    "    print('{:.0f}\\t{:.4f}\\t{:.4f}'.format(i, exact_loss, linear_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vc2FaYtEDBJ_",
    "pycharm": {}
   },
   "source": [
    "## Momentum, Cross Entropy Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L4onegU1DBKA",
    "pycharm": {}
   },
   "source": [
    "Create a optimizer and initialize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cxoiw-DADBKB",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "mass = 0.9\n",
    "opt_init, opt_apply, get_params = momentum(learning_rate, mass)\n",
    "state = opt_init(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "63VJ8y9FDBKE",
    "pycharm": {}
   },
   "source": [
    "Create a Cross Entropy loss and a gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "e8SxBiZXDBKE",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "loss = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)\n",
    "grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "t7GiiW-LDBKI",
    "pycharm": {}
   },
   "source": [
    "Create a momentum predictor and initialize it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8fpKKqPaDBKJ",
    "pycharm": {}
   },
   "outputs": [],
   "source": [
    "predictor = nt.predict.gradient_descent(loss,\n",
    "    g_dd, train['label'], learning_rate=learning_rate, momentum=mass)\n",
    "fx_train = f(params, train['image'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jW9ws4fMDBKL",
    "pycharm": {}
   },
   "source": [
    "Train the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 204
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 5582,
     "status": "ok",
     "timestamp": 1588652659738,
     "user": {
      "displayName": "",
      "photoUrl": "",
      "userId": ""
     },
     "user_tz": 420
    },
    "id": "_pfseUitDBKM",
    "outputId": "e6c9403d-1152-4e84-8479-211a3089ddcb",
    "pycharm": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time\tLoss\tLinear Loss\n",
      "0\t0.0680\t0.0680\n",
      "100\t0.0399\t0.0401\n",
      "200\t0.0262\t0.0266\n",
      "300\t0.0191\t0.0195\n",
      "400\t0.0148\t0.0153\n",
      "500\t0.0120\t0.0126\n",
      "600\t0.0101\t0.0106\n",
      "700\t0.0086\t0.0092\n",
      "800\t0.0075\t0.0081\n",
      "900\t0.0067\t0.0072\n"
     ]
    }
   ],
   "source": [
    "print ('Time\\tLoss\\tLinear Loss')\n",
    "\n",
    "X, Y = train['image'], train['label']\n",
    "\n",
    "predictions = predictor(training_steps, fx_train)\n",
    "\n",
    "for i in training_steps:\n",
    "  params = get_params(state)\n",
    "  state = opt_apply(i, grad_loss(params, X, Y), state)\n",
    "\n",
    "  if i % print_every == 0:\n",
    "    exact_loss = loss(f(params, X), Y)\n",
    "    linear_loss = loss(predictions[i], Y)\n",
    "    print('{:.0f}\\t{:.4f}\\t{:.4f}'.format(i, exact_loss, linear_loss))\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3",
    "kind": "private"
   },
   "name": "Function Space Linearization.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
